{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Graph4NLP Demo: Math Word Problem\n",
    "\n",
    "---\n",
    "\n",
    "In this demo, we take a closer look at how to apply the **Graph2Tree model to the task of automatically solving math word problems**.\n",
    "Math word problem solving aims to infer reasonable equations from given natural language problem descriptions. It is important for exploring automatic solutions to mathematical problems and improving the reasoning ability of neural networks. \n",
    "In this demo, we use the Graph4NLP library to build a GNN-based math word problem (MWP) solving model. \n",
    "\n",
    "The **Graph2Tree** model consists of:\n",
    "\n",
    "- graph construction module (e.g., node embedding based dynamic graph)\n",
    "- graph embedding module (e.g., undirected GraphSAGE)\n",
    "- prediction module (e.g., tree decoder with attention and copy mechanisms)\n",
    "\n",
    "As shown in the picture below, we first construct the graph input from the problem description by syntactic parsing (CoreNLP) and then represent the output equation with a hierarchical structure (node ``N`` stands for a non-terminal node).\n",
    "\n",
    "<p align=\"center\">\n",
    "<img src=\"./imgs/g2t.png\" width=\"600\" class=\"center\" alt=\"graph2tree_mwp\"/>\n",
    "    <br/>\n",
    "</p>\n",
    "\n",
    "We will use the built-in Graph2Tree model APIs to build the model and evaluate it on the MAWPS dataset."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Environment setup\n",
    "---\n",
    "\n",
    "Please follow the instructions [here](https://github.com/graph4ai/graph4nlp_demo#environment-setup) to set up the environment, then run the following commands to install the extra packages used in this demo.\n",
    "```\n",
    "pip install sympy\n",
    "pip install ipywidgets\n",
    "```\n",
    "\n",
    "\n",
    "This notebook was tested with:\n",
    "\n",
    "```\n",
    "torch == 1.9.0\n",
    "torchtext == 0.10.0\n",
    "```"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2021-07-22T13:51:01.626460Z",
     "start_time": "2021-07-22T13:51:01.615750Z"
    }
   },
   "source": [
    "## Load the config file"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2021-07-27T10:13:55.123979Z",
     "start_time": "2021-07-27T10:13:55.092616Z"
    },
    "scrolled": false
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Using backend: pytorch\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'batch_size': 20,\n",
      " 'beam_size': 4,\n",
      " 'checkpoint_save_path': './checkpoint_save',\n",
      " 'dataset_yaml': './config.yaml',\n",
      " 'decoder_args': {'rnn_decoder_private': {'max_decoder_step': 35,\n",
      "                                          'max_tree_depth': 8,\n",
      "                                          'use_input_feed': True,\n",
      "                                          'use_sibling': False},\n",
      "                  'rnn_decoder_share': {'attention_type': 'uniform',\n",
      "                                        'dropout': 0.3,\n",
      "                                        'fuse_strategy': 'concatenate',\n",
      "                                        'graph_pooling_strategy': None,\n",
      "                                        'hidden_size': 300,\n",
      "                                        'input_size': 300,\n",
      "                                        'rnn_emb_input_size': 300,\n",
      "                                        'rnn_type': 'lstm',\n",
      "                                        'teacher_forcing_rate': 1.0,\n",
      "                                        'use_copy': True,\n",
      "                                        'use_coverage': False}},\n",
      " 'decoder_name': 'stdtree',\n",
      " 'gpuid': -1,\n",
      " 'grad_clip': 5,\n",
      " 'graph_construction_args': {'graph_construction_private': {'as_node': False,\n",
      "                                                            'edge_strategy': 'homogeneous',\n",
      "                                                            'merge_strategy': 'tailhead',\n",
      "                                                            'sequential_link': True},\n",
      "                             'graph_construction_share': {'graph_type': 'dependency',\n",
      "                                                          'port': 9000,\n",
      "                                                          'root_dir': './data',\n",
      "                                                          'share_vocab': True,\n",
      "                                                          'thread_number': 4,\n",
      "                                                          'timeout': 15000,\n",
      "                                                          'topology_subdir': 'DependencyGraph'},\n",
      "                             'node_embedding': {'connectivity_ratio': 0.05,\n",
      "                                                'embedding_style': {'bert_lower_case': None,\n",
      "                                                                    'bert_model_name': None,\n",
      "                                                                    'emb_strategy': 'w2v_bilstm',\n",
      "                                                                    'num_rnn_layers': 1,\n",
      "                                                                    'single_token_item': True},\n",
      "                                                'epsilon_neigh': 0.5,\n",
      "                                                'fix_bert_emb': False,\n",
      "                                                'fix_word_emb': False,\n",
      "                                                'hidden_size': 300,\n",
      "                                                'input_size': 300,\n",
      "                                                'num_heads': 1,\n",
      "                                                'rnn_dropout': 0.1,\n",
      "                                                'sim_metric_type': 'weighted_cosine',\n",
      "                                                'smoothness_ratio': 0.1,\n",
      "                                                'sparsity_ratio': 0.1,\n",
      "                                                'top_k_neigh': None,\n",
      "                                                'word_dropout': 0.1}},\n",
      " 'graph_construction_name': 'dependency',\n",
      " 'graph_embedding_args': {'graph_embedding_private': {'activation': 'relu',\n",
      "                                                      'aggregator_type': 'lstm',\n",
      "                                                      'bias': True,\n",
      "                                                      'norm': None,\n",
      "                                                      'use_edge_weight': False},\n",
      "                          'graph_embedding_share': {'attn_drop': 0.0,\n",
      "                                                    'direction_option': 'undirected',\n",
      "                                                    'feat_drop': 0.0,\n",
      "                                                    'hidden_size': 300,\n",
      "                                                    'input_size': 300,\n",
      "                                                    'num_layers': 1,\n",
      "                                                    'output_size': 300}},\n",
      " 'graph_embedding_name': 'graphsage',\n",
      " 'graph_type': 'static',\n",
      " 'init_weight': 0.08,\n",
      " 'learning_rate': 0.001,\n",
      " 'max_epochs': 20,\n",
      " 'min_freq': 1,\n",
      " 'pretrained_word_emb_cache_dir': '.vector_cache',\n",
      " 'pretrained_word_emb_name': None,\n",
      " 'pretrained_word_emb_url': None,\n",
      " 'seed': 123,\n",
      " 'share_vocab': True,\n",
      " 'weight_decay': 0}\n"
     ]
    }
   ],
   "source": [
    "from graph4nlp.pytorch.modules.config import get_basic_args\n",
    "from graph4nlp.pytorch.modules.utils.config_utils import update_values, get_yaml_config\n",
    "\n",
    "def get_args():\n",
    "    config = {'dataset_yaml': \"./config.yaml\",\n",
    "              'learning_rate': 1e-3,\n",
    "              'gpuid': -1,\n",
    "              'seed': 123, \n",
    "              'init_weight': 0.08,\n",
    "              'graph_type': 'static',\n",
    "              'weight_decay': 0, \n",
    "              'max_epochs': 20, \n",
    "              'min_freq': 1,\n",
    "              'grad_clip': 5,\n",
    "              'batch_size': 20,\n",
    "              'share_vocab': True,\n",
    "              'pretrained_word_emb_name': None,\n",
    "              'pretrained_word_emb_url': None,\n",
    "              'pretrained_word_emb_cache_dir': \".vector_cache\",\n",
    "              'checkpoint_save_path': \"./checkpoint_save\",\n",
    "              'beam_size': 4\n",
    "              }\n",
    "    our_args = get_yaml_config(config['dataset_yaml'])\n",
    "    template = get_basic_args(graph_construction_name=our_args[\"graph_construction_name\"],\n",
    "                              graph_embedding_name=our_args[\"graph_embedding_name\"],\n",
    "                              decoder_name=our_args[\"decoder_name\"])\n",
    "    update_values(to_args=template, from_args_list=[our_args, config])\n",
    "    return template\n",
    "\n",
    "# show our config\n",
    "cfg_g2t = get_args()\n",
    "from pprint import pprint\n",
    "pprint(cfg_g2t)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2021-07-27T10:13:55.452276Z",
     "start_time": "2021-07-27T10:13:55.364294Z"
    }
   },
   "outputs": [],
   "source": [
    "import copy\n",
    "import torch\n",
    "import random\n",
    "import argparse\n",
    "import numpy as np\n",
    "import torch.optim as optim\n",
    "from torch.utils.data import DataLoader\n",
    "from tqdm.notebook import tqdm\n",
    "\n",
    "from graph4nlp.pytorch.data.data import to_batch\n",
    "from graph4nlp.pytorch.datasets.mawps import MawpsDatasetForTree\n",
    "from graph4nlp.pytorch.modules.graph_construction import DependencyBasedGraphConstruction\n",
    "from graph4nlp.pytorch.modules.graph_embedding import *\n",
    "from graph4nlp.pytorch.models.graph2tree import Graph2Tree\n",
    "from graph4nlp.pytorch.modules.utils.tree_utils import Tree\n",
    "\n",
    "from utils import convert_to_string, compute_tree_accuracy, prepare_oov"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2021-07-27T10:13:55.669860Z",
     "start_time": "2021-07-27T10:13:55.640456Z"
    }
   },
   "outputs": [],
   "source": [
    "class Mawps:\n",
    "    def __init__(self, opt=None):\n",
    "        super(Mawps, self).__init__()\n",
    "        self.opt = opt\n",
    "\n",
    "        seed = self.opt[\"seed\"]\n",
    "        random.seed(seed)\n",
    "        np.random.seed(seed)\n",
    "        torch.manual_seed(seed)\n",
    "\n",
    "        if self.opt[\"gpuid\"] == -1:\n",
    "            self.device = torch.device(\"cpu\")\n",
    "        else:\n",
    "            self.device = torch.device(\"cuda:{}\".format(self.opt[\"gpuid\"]))\n",
    "\n",
    "        self.use_copy = self.opt[\"decoder_args\"][\"rnn_decoder_share\"][\"use_copy\"]\n",
    "        self.use_share_vocab = self.opt[\"graph_construction_args\"][\"graph_construction_share\"][\"share_vocab\"]\n",
    "        self.data_dir = self.opt[\"graph_construction_args\"][\"graph_construction_share\"][\"root_dir\"]\n",
    "\n",
    "        self._build_dataloader()\n",
    "        self._build_model()\n",
    "        self._build_optimizer()\n",
    "\n",
    "    def _build_dataloader(self):\n",
    "        para_dic =  {'root_dir': self.data_dir,\n",
    "                    'word_emb_size': self.opt[\"graph_construction_args\"][\"node_embedding\"][\"input_size\"],\n",
    "                    'topology_builder': DependencyBasedGraphConstruction,\n",
    "                    'topology_subdir': self.opt[\"graph_construction_args\"][\"graph_construction_share\"][\"topology_subdir\"], \n",
    "                    'edge_strategy': self.opt[\"graph_construction_args\"][\"graph_construction_private\"][\"edge_strategy\"],\n",
    "                    'graph_type': 'static',\n",
    "                    'dynamic_graph_type': self.opt[\"graph_construction_args\"][\"graph_construction_share\"][\"graph_type\"], \n",
    "                    'share_vocab': self.use_share_vocab, \n",
    "                    'enc_emb_size': self.opt[\"graph_construction_args\"][\"node_embedding\"][\"input_size\"],\n",
    "                    'dec_emb_size': self.opt[\"decoder_args\"][\"rnn_decoder_share\"][\"input_size\"],\n",
    "                    'dynamic_init_topology_builder': None,\n",
    "                    'min_word_vocab_freq': self.opt[\"min_freq\"],\n",
    "                    'pretrained_word_emb_name': self.opt[\"pretrained_word_emb_name\"],\n",
    "                    'pretrained_word_emb_url': self.opt[\"pretrained_word_emb_url\"], \n",
    "                    'pretrained_word_emb_cache_dir': self.opt[\"pretrained_word_emb_cache_dir\"]\n",
    "                    }\n",
    "\n",
    "        dataset = MawpsDatasetForTree(**para_dic)\n",
    "\n",
    "        self.train_data_loader = DataLoader(dataset.train, batch_size=self.opt[\"batch_size\"], shuffle=True,\n",
    "                                            num_workers=0,\n",
    "                                            collate_fn=dataset.collate_fn)\n",
    "        self.test_data_loader = DataLoader(dataset.test, batch_size=1, shuffle=False, num_workers=0,\n",
    "                                           collate_fn=dataset.collate_fn)\n",
    "        self.valid_data_loader = DataLoader(dataset.val, batch_size=1, shuffle=False, num_workers=0,\n",
    "                                          collate_fn=dataset.collate_fn)\n",
    "        self.vocab_model = dataset.vocab_model\n",
    "        self.src_vocab = self.vocab_model.in_word_vocab\n",
    "        self.tgt_vocab = self.vocab_model.out_word_vocab\n",
    "        self.share_vocab = self.vocab_model.share_vocab if self.use_share_vocab else None\n",
    "\n",
    "    def _build_model(self):\n",
    "        '''For encoder-decoder'''\n",
    "        self.model = Graph2Tree.from_args(self.opt, \n",
    "                                          vocab_model=self.vocab_model)\n",
    "        self.model.init(self.opt[\"init_weight\"])\n",
    "        self.model.to(self.device)\n",
    "\n",
    "    def _build_optimizer(self):\n",
    "        optim_state = {\"learningRate\": self.opt[\"learning_rate\"], \"weight_decay\": self.opt[\"weight_decay\"]}\n",
    "        parameters = [p for p in self.model.parameters() if p.requires_grad]\n",
    "        self.optimizer = optim.Adam(parameters, lr=optim_state['learningRate'], weight_decay=optim_state['weight_decay'])\n",
    "\n",
    "    def train_epoch(self, epoch):\n",
    "        loss_to_print = 0\n",
    "        num_batch = len(self.train_data_loader)\n",
    "        for step, data in tqdm(enumerate(self.train_data_loader), desc=f'Epoch {epoch:02d}', total=len(self.train_data_loader)):\n",
    "            batch_graph, batch_tree_list, batch_original_tree_list = data['graph_data'], data['dec_tree_batch'], data['original_dec_tree_batch']\n",
    "            batch_graph = batch_graph.to(self.device)\n",
    "            self.optimizer.zero_grad()\n",
    "            oov_dict = prepare_oov(\n",
    "                batch_graph, self.src_vocab, self.device) if self.use_copy else None\n",
    "\n",
    "            if self.use_copy:\n",
    "                batch_tree_list_refined = []\n",
    "                for item in batch_original_tree_list:\n",
    "                    tgt_list = oov_dict.get_symbol_idx_for_list(item.strip().split())\n",
    "                    tgt_tree = Tree.convert_to_tree(tgt_list, 0, len(tgt_list), oov_dict)\n",
    "                    batch_tree_list_refined.append(tgt_tree)\n",
    "            loss = self.model(batch_graph, batch_tree_list_refined if self.use_copy else batch_tree_list, oov_dict=oov_dict)\n",
    "            loss.backward()\n",
    "            torch.nn.utils.clip_grad_value_(\n",
    "                self.model.parameters(), self.opt[\"grad_clip\"])\n",
    "            self.optimizer.step()\n",
    "            loss_to_print += loss\n",
    "        return loss_to_print/num_batch\n",
    "\n",
    "    def train(self):\n",
    "        best_acc = -1\n",
    "        best_model = None\n",
    "\n",
    "        print(\"-------------\\nStarting training.\")\n",
    "        for epoch in range(1, self.opt[\"max_epochs\"]+1):\n",
    "            self.model.train()\n",
    "            loss_to_print = self.train_epoch(epoch)\n",
    "            print(\"epochs = {}, train_loss = {:.3f}\".format(epoch, loss_to_print))\n",
    "            if epoch > 15:\n",
    "                val_acc = self.eval(self.model, mode=\"val\")\n",
    "                if val_acc > best_acc:\n",
    "                    best_acc = val_acc\n",
    "                    best_model = self.model\n",
    "        self.eval(best_model, mode=\"test\")\n",
    "        best_model.save_checkpoint(self.opt[\"checkpoint_save_path\"], \"best.pt\")\n",
    "\n",
    "    def eval(self, model, mode=\"val\"):\n",
    "        model.eval()\n",
    "        reference_list = []\n",
    "        candidate_list = []\n",
    "        data_loader = self.test_data_loader if mode == \"test\" else self.valid_data_loader\n",
    "        for data in tqdm(data_loader, desc=\"Eval: \"):\n",
    "            eval_input_graph, batch_tree_list, batch_original_tree_list = data['graph_data'], data['dec_tree_batch'], data['original_dec_tree_batch']\n",
    "            eval_input_graph = eval_input_graph.to(self.device)\n",
    "            oov_dict = prepare_oov(eval_input_graph, self.src_vocab, self.device)\n",
    "\n",
    "            if self.use_copy:\n",
    "                assert len(batch_original_tree_list) == 1\n",
    "                reference = oov_dict.get_symbol_idx_for_list(batch_original_tree_list[0].split())\n",
    "                eval_vocab = oov_dict\n",
    "            else:\n",
    "                assert len(batch_original_tree_list) == 1\n",
    "                reference = model.tgt_vocab.get_symbol_idx_for_list(batch_original_tree_list[0].split())\n",
    "                eval_vocab = self.tgt_vocab\n",
    "            \n",
    "            candidate = model.translate(eval_input_graph,\n",
    "                                        oov_dict=oov_dict,\n",
    "                                        use_beam_search=True,\n",
    "                                        beam_size=self.opt[\"beam_size\"])\n",
    "            \n",
    "            candidate = [int(c) for c in candidate]\n",
    "            num_left_paren = sum(\n",
    "                1 for c in candidate if eval_vocab.idx2symbol[int(c)] == \"(\")\n",
    "            num_right_paren = sum(\n",
    "                1 for c in candidate if eval_vocab.idx2symbol[int(c)] == \")\")\n",
    "            diff = num_left_paren - num_right_paren\n",
    "            if diff > 0:\n",
    "                for i in range(diff):\n",
    "                    candidate.append(\n",
    "                        self.test_data_loader.tgt_vocab.symbol2idx[\")\"])\n",
    "            elif diff < 0:\n",
    "                candidate = candidate[:diff]\n",
    "            ref_str = convert_to_string(\n",
    "                reference, eval_vocab)\n",
    "            cand_str = convert_to_string(\n",
    "                candidate, eval_vocab)\n",
    "\n",
    "            reference_list.append(reference)\n",
    "            candidate_list.append(candidate)\n",
    "        eval_acc = compute_tree_accuracy(\n",
    "            candidate_list, reference_list, eval_vocab)\n",
    "        print(\"{} accuracy = {:.3f}\\n\".format(mode, eval_acc))\n",
    "        return eval_acc"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2021-07-27T10:13:58.783006Z",
     "start_time": "2021-07-27T10:13:57.875582Z"
    },
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/Users/hugo/opt/anaconda3/envs/g4l_202108/lib/python3.7/site-packages/torch-1.9.0-py3.7-macosx-10.9-x86_64.egg/torch/nn/_reduction.py:42: UserWarning: size_average and reduce args will be deprecated, please use reduction='sum' instead.\n",
      "  warnings.warn(warning.format(ret))\n"
     ]
    }
   ],
   "source": [
    "a = Mawps(cfg_g2t)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2021-07-27T10:14:07.766389Z",
     "start_time": "2021-07-27T10:13:59.534775Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "-------------\n",
      "Starting training.\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "897857db3e1a4b87a20fd431bbffa875",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Epoch 01:   0%|          | 0/94 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/Users/hugo/opt/anaconda3/envs/g4l_202108/lib/python3.7/site-packages/torch-1.9.0-py3.7-macosx-10.9-x86_64.egg/torch/nn/functional.py:1794: UserWarning: nn.functional.tanh is deprecated. Use torch.tanh instead.\n",
      "  warnings.warn(\"nn.functional.tanh is deprecated. Use torch.tanh instead.\")\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epochs = 1, train_loss = 24.041\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "68aee1fe06534291adbf3c3c591b2310",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Epoch 02:   0%|          | 0/94 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epochs = 2, train_loss = 17.632\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "ed60cb4ad71945dcbf99cc13d8a03976",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Epoch 03:   0%|          | 0/94 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epochs = 3, train_loss = 15.362\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "a3e26e73ed8641deb2057f713d847f63",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Epoch 04:   0%|          | 0/94 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epochs = 4, train_loss = 14.293\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "45f36e8cdb3c4ab5acd343cab7d50ea8",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Epoch 05:   0%|          | 0/94 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epochs = 5, train_loss = 13.127\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "dfdbb2a9f8c6404ca7701ce1f9c33d19",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Epoch 06:   0%|          | 0/94 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epochs = 6, train_loss = 11.714\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "c55a4b4e04ab42ed84c09459629ed98b",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Epoch 07:   0%|          | 0/94 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epochs = 7, train_loss = 10.822\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "2031b4fdb76d40309a432a2941c6be6f",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Epoch 08:   0%|          | 0/94 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epochs = 8, train_loss = 10.217\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "c54fc9b60c354f028bd3499db454474b",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Epoch 09:   0%|          | 0/94 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epochs = 9, train_loss = 9.805\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "5aac2e0a284a425a96ef87280bbbaad9",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Epoch 10:   0%|          | 0/94 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epochs = 10, train_loss = 9.400\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "d4eeb8c14f414dbda71f9935325daf0a",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Epoch 11:   0%|          | 0/94 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epochs = 11, train_loss = 9.254\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "6ad96e08d6c34c4db500718875869a13",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Epoch 12:   0%|          | 0/94 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "best_acc = a.train()"
   ]
  }
 ],
 "metadata": {
  "interpreter": {
   "hash": "53b113626f4385b7a52334a3b11ec5d9307ad80c73f59f759f44504bc95f0ff2"
  },
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.11"
  },
  "latex_envs": {
   "LaTeX_envs_menu_present": true,
   "autoclose": false,
   "autocomplete": true,
   "bibliofile": "biblio.bib",
   "cite_by": "apalike",
   "current_citInitial": 1,
   "eqLabelWithNumbers": true,
   "eqNumInitial": 1,
   "hotkeys": {
    "equation": "Ctrl-E",
    "itemize": "Ctrl-I"
   },
   "labels_anchors": false,
   "latex_user_defs": false,
   "report_style_numbering": false,
   "user_envs_cfg": false
  },
  "toc": {
   "base_numbering": 1,
   "nav_menu": {},
   "number_sections": true,
   "sideBar": true,
   "skip_h1_title": false,
   "title_cell": "Table of Contents",
   "title_sidebar": "Contents",
   "toc_cell": false,
   "toc_position": {},
   "toc_section_display": true,
   "toc_window_display": false
  },
  "varInspector": {
   "cols": {
    "lenName": 16,
    "lenType": 16,
    "lenVar": 40
   },
   "kernels_config": {
    "python": {
     "delete_cmd_postfix": "",
     "delete_cmd_prefix": "del ",
     "library": "var_list.py",
     "varRefreshCmd": "print(var_dic_list())"
    },
    "r": {
     "delete_cmd_postfix": ") ",
     "delete_cmd_prefix": "rm(",
     "library": "var_list.r",
     "varRefreshCmd": "cat(var_dic_list()) "
    }
   },
   "types_to_exclude": [
    "module",
    "function",
    "builtin_function_or_method",
    "instance",
    "_Feature"
   ],
   "window_display": false
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
